// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package mssql

import (
	"fmt"
	"log"
	"time"

	"github.com/hashicorp/go-azure-helpers/lang/pointer"
	"github.com/hashicorp/go-azure-helpers/lang/response"
	"github.com/hashicorp/go-azure-helpers/resourcemanager/commonids"
	"github.com/hashicorp/go-azure-sdk/resource-manager/sql/2023-08-01-preview/serversecurityalertpolicies"
	"github.com/hashicorp/go-azure-sdk/resource-manager/sql/2023-08-01-preview/servervulnerabilityassessments"
	"github.com/hashicorp/terraform-provider-azurerm/helpers/azure"
	"github.com/hashicorp/terraform-provider-azurerm/internal/clients"
	"github.com/hashicorp/terraform-provider-azurerm/internal/services/mssql/parse"
	"github.com/hashicorp/terraform-provider-azurerm/internal/tf/pluginsdk"
	"github.com/hashicorp/terraform-provider-azurerm/internal/tf/validation"
	"github.com/hashicorp/terraform-provider-azurerm/internal/timeouts"
)

// resourceMsSqlServerVulnerabilityAssessment describes the server-level SQL
// vulnerability assessment resource: its CRUD handlers, import ID validation,
// operation timeouts and schema.
func resourceMsSqlServerVulnerabilityAssessment() *pluginsdk.Resource {
	// Attributes of the single optional "recurring_scans" block.
	recurringScansSchema := map[string]*pluginsdk.Schema{
		"enabled": {
			Type:     pluginsdk.TypeBool,
			Optional: true,
			Default:  false,
		},

		"email_subscription_admins": {
			Type:     pluginsdk.TypeBool,
			Optional: true,
			Default:  false,
		},

		"emails": {
			Type:     pluginsdk.TypeList,
			Optional: true,
			Elem: &pluginsdk.Schema{
				Type:         pluginsdk.TypeString,
				ValidateFunc: validation.StringIsNotEmpty,
			},
		},
	}

	resource := &pluginsdk.Resource{
		Create: resourceMsSqlServerVulnerabilityAssessmentCreate,
		Read:   resourceMsSqlServerVulnerabilityAssessmentRead,
		Update: resourceMsSqlServerVulnerabilityAssessmentUpdate,
		Delete: resourceMsSqlServerVulnerabilityAssessmentDelete,

		// Imports are accepted only for IDs that parse as a server vulnerability assessment ID.
		Importer: pluginsdk.ImporterValidatingResourceId(func(id string) error {
			_, err := parse.ServerVulnerabilityAssessmentID(id)
			return err
		}),

		Timeouts: &pluginsdk.ResourceTimeout{
			Create: pluginsdk.DefaultTimeout(30 * time.Minute),
			Read:   pluginsdk.DefaultTimeout(5 * time.Minute),
			Update: pluginsdk.DefaultTimeout(30 * time.Minute),
			Delete: pluginsdk.DefaultTimeout(30 * time.Minute),
		},

		Schema: map[string]*pluginsdk.Schema{
			"server_security_alert_policy_id": {
				Type:         pluginsdk.TypeString,
				Required:     true,
				ForceNew:     true,
				ValidateFunc: azure.ValidateResourceID,
			},

			"storage_container_path": {
				Type:         pluginsdk.TypeString,
				Required:     true,
				ValidateFunc: validation.StringIsNotEmpty,
			},

			"storage_account_access_key": {
				Type:         pluginsdk.TypeString,
				Optional:     true,
				Sensitive:    true,
				ValidateFunc: validation.StringIsNotEmpty,
			},

			"storage_container_sas_key": {
				Type:         pluginsdk.TypeString,
				Optional:     true,
				Sensitive:    true,
				ValidateFunc: validation.StringIsNotEmpty,
			},

			"recurring_scans": {
				Type:     pluginsdk.TypeList,
				Optional: true,
				Computed: true,
				MaxItems: 1,
				Elem: &pluginsdk.Resource{
					Schema: recurringScansSchema,
				},
			},
		},
	}

	return resource
}

// resourceMsSqlServerVulnerabilityAssessmentCreate configures the vulnerability
// assessment for the SQL server referenced by `server_security_alert_policy_id`.
//
// The server's security alert policy must exist and be Enabled; otherwise an
// error is returned before anything is written. The assessment is always
// recorded under the name "default".
func resourceMsSqlServerVulnerabilityAssessmentCreate(d *pluginsdk.ResourceData, meta interface{}) error {
	client := meta.(*clients.Client).MSSQL.ServerVulnerabilityAssessmentsClient
	alertClient := meta.(*clients.Client).MSSQL.ServerSecurityAlertPoliciesClient
	ctx, cancel := timeouts.ForCreate(meta.(*clients.Client).StopContext, d)
	defer cancel()

	alertId, err := parse.ServerSecurityAlertPolicyID(d.Get("server_security_alert_policy_id").(string))
	if err != nil {
		return err
	}

	serverId := commonids.NewSqlServerID(alertId.SubscriptionId, alertId.ResourceGroup, alertId.ServerName)

	alertResult, err := alertClient.Get(ctx, serverId)
	if err != nil {
		return fmt.Errorf("retrieving security alert policy for SQL server %q: %w", serverId.ServerName, err)
	}

	model := alertResult.Model
	if model == nil {
		return fmt.Errorf("security alert policy for SQL server %q returned empty response", serverId.ServerName)
	}

	alertProps := model.Properties
	if alertProps == nil {
		return fmt.Errorf("security alert policy for SQL server %q has no properties", serverId.ServerName)
	}

	if alertProps.State != serversecurityalertpolicies.SecurityAlertsPolicyStateEnabled {
		return fmt.Errorf("security alert policy for SQL server %q must be enabled before configuring vulnerability assessment", serverId.ServerName)
	}

	log.Printf("[INFO] preparing arguments for mssql server vulnerability assessment creation")

	id := parse.NewServerVulnerabilityAssessmentID(serverId.SubscriptionId, serverId.ResourceGroupName, serverId.ServerName, "default")

	props := &servervulnerabilityassessments.ServerVulnerabilityAssessmentProperties{
		StorageContainerPath: d.Get("storage_container_path").(string),
	}

	if v, ok := d.GetOk("storage_account_access_key"); ok {
		props.StorageAccountAccessKey = pointer.To(v.(string))
	}

	if v, ok := d.GetOk("storage_container_sas_key"); ok {
		props.StorageContainerSasKey = pointer.To(v.(string))
	}

	// RecurringScans is always sent; when the block is absent it is an empty object.
	recurringScanProps := servervulnerabilityassessments.VulnerabilityAssessmentRecurringScansProperties{}

	if raw, ok := d.GetOk("recurring_scans"); ok {
		if blocks := raw.([]interface{}); len(blocks) != 0 {
			// Comma-ok assertion: a nil list element is treated as "no settings"
			// instead of panicking.
			if v, ok := blocks[0].(map[string]interface{}); ok {
				if enabled, ok := v["enabled"]; ok {
					recurringScanProps.IsEnabled = pointer.To(enabled.(bool))
				}

				if emailAdmins, ok := v["email_subscription_admins"]; ok {
					recurringScanProps.EmailSubscriptionAdmins = pointer.To(emailAdmins.(bool))
				}

				if rawEmails, ok := v["emails"]; ok {
					emails := make([]string, 0)
					for _, email := range rawEmails.([]interface{}) {
						emails = append(emails, email.(string))
					}
					recurringScanProps.Emails = pointer.To(emails)
				}
			}
		}
	}
	props.RecurringScans = pointer.To(recurringScanProps)

	payload := servervulnerabilityassessments.ServerVulnerabilityAssessment{
		Properties: props,
	}

	result, err := client.CreateOrUpdate(ctx, serverId, payload)
	if err != nil {
		return fmt.Errorf("creating vulnerability assessment for SQL server %q: %w", serverId.ServerName, err)
	}

	// A successful call without an ID is still a failure. Report it separately so
	// the message does not wrap a nil error.
	if result.Model == nil || result.Model.Id == nil {
		return fmt.Errorf("creating vulnerability assessment for SQL server %q: response did not contain a resource ID", serverId.ServerName)
	}

	d.SetId(id.ID())

	return resourceMsSqlServerVulnerabilityAssessmentRead(d, meta)
}

// resourceMsSqlServerVulnerabilityAssessmentRead refreshes state for a server
// vulnerability assessment.
//
// If the assessment returns 404, the resource is removed from state. The
// server's security alert policy is also read, and its API-returned ID is
// written to `server_security_alert_policy_id`.
func resourceMsSqlServerVulnerabilityAssessmentRead(d *pluginsdk.ResourceData, meta interface{}) error {
	client := meta.(*clients.Client).MSSQL.ServerVulnerabilityAssessmentsClient
	alertClient := meta.(*clients.Client).MSSQL.ServerSecurityAlertPoliciesClient
	ctx, cancel := timeouts.ForRead(meta.(*clients.Client).StopContext, d)
	defer cancel()

	log.Printf("[INFO] reading mssql server vulnerability assessment")

	id, err := parse.ServerVulnerabilityAssessmentID(d.Id())
	if err != nil {
		return err
	}

	serverId := commonids.NewSqlServerID(id.SubscriptionId, id.ResourceGroup, id.ServerName)

	vulnerability, err := client.Get(ctx, serverId)
	if err != nil {
		if response.WasNotFound(vulnerability.HttpResponse) {
			log.Printf("[WARN] vulnerability assessment for SQL server %q was not found", serverId.ServerName)
			d.SetId("")
			return nil
		}

		return fmt.Errorf("retrieving vulnerability assessment for SQL server %q: %w", serverId.ServerName, err)
	}

	alert, err := alertClient.Get(ctx, serverId)
	if err != nil {
		return fmt.Errorf("retrieving security alert policy for SQL server %q: %w", serverId.ServerName, err)
	}

	model := alert.Model

	if model == nil {
		return fmt.Errorf("security alert policy for SQL server %q returned empty response", serverId.ServerName)
	}

	if model.Id == nil {
		return fmt.Errorf("security alert policy for SQL server %q has no resource ID", serverId.ServerName)
	}

	recurringScans := make([]interface{}, 0)
	var storageContainerPath string

	if vModel := vulnerability.Model; vModel != nil {
		if props := vModel.Properties; props != nil {
			storageContainerPath = props.StorageContainerPath

			// A single block is written whenever properties are present, even if
			// the API omitted recurringScans (the block then carries zero values).
			// NOTE(review): presumably to avoid a diff against a configured block of
			// defaults; confirm before changing.
			recurringScansResult := make(map[string]interface{})
			if rs := props.RecurringScans; rs != nil {
				recurringScansResult["enabled"] = pointer.From(rs.IsEnabled)
				recurringScansResult["email_subscription_admins"] = pointer.From(rs.EmailSubscriptionAdmins)
				recurringScansResult["emails"] = pointer.From(rs.Emails)
			}

			recurringScans = []interface{}{recurringScansResult}
		}
	}

	d.Set("server_security_alert_policy_id", model.Id)
	d.Set("storage_container_path", storageContainerPath)

	// Setting a nested list can fail on a shape mismatch, so surface that error
	// instead of silently dropping it.
	if err := d.Set("recurring_scans", recurringScans); err != nil {
		return fmt.Errorf("setting `recurring_scans`: %w", err)
	}

	// The API does not return the storage credentials, so the values already in
	// state are carried forward unchanged. An unset value reads as "".
	d.Set("storage_account_access_key", d.Get("storage_account_access_key").(string))
	d.Set("storage_container_sas_key", d.Get("storage_container_sas_key").(string))

	return nil
}

// resourceMsSqlServerVulnerabilityAssessmentUpdate applies configuration
// changes to an existing server vulnerability assessment.
//
// The server's security alert policy must still be Enabled. The current
// assessment is fetched and used as the base payload. Only changed fields are
// overwritten, except the storage credentials, which are re-sent whenever they
// are set because the API does not return them.
func resourceMsSqlServerVulnerabilityAssessmentUpdate(d *pluginsdk.ResourceData, meta interface{}) error {
	client := meta.(*clients.Client).MSSQL.ServerVulnerabilityAssessmentsClient
	alertClient := meta.(*clients.Client).MSSQL.ServerSecurityAlertPoliciesClient
	ctx, cancel := timeouts.ForUpdate(meta.(*clients.Client).StopContext, d)
	defer cancel()

	id, err := parse.ServerVulnerabilityAssessmentID(d.Id())
	if err != nil {
		return err
	}

	serverId := commonids.NewSqlServerID(id.SubscriptionId, id.ResourceGroup, id.ServerName)

	alert, err := alertClient.Get(ctx, serverId)
	if err != nil {
		return fmt.Errorf("retrieving security alert policy for SQL server %q: %w", serverId.ServerName, err)
	}

	if alert.Model == nil {
		return fmt.Errorf("security alert policy for SQL server %q returned empty response", serverId.ServerName)
	}

	if alert.Model.Properties == nil {
		return fmt.Errorf("security alert policy for SQL server %q has no properties", serverId.ServerName)
	}

	if alert.Model.Properties.State != serversecurityalertpolicies.SecurityAlertsPolicyStateEnabled {
		return fmt.Errorf("security alert policy for SQL server %q must be enabled before updating vulnerability assessment", serverId.ServerName)
	}

	log.Printf("[INFO] preparing arguments for mssql server vulnerability assessment update")

	result, err := client.Get(ctx, serverId)
	if err != nil {
		return fmt.Errorf("retrieving existing vulnerability assessment for SQL server %q: %w", serverId.ServerName, err)
	}

	if result.Model == nil {
		return fmt.Errorf("vulnerability assessment for SQL server %q returned empty response", serverId.ServerName)
	}

	props := result.Model.Properties

	if props == nil {
		return fmt.Errorf("vulnerability assessment for SQL server %q has no properties", serverId.ServerName)
	}

	if d.HasChange("storage_container_path") {
		props.StorageContainerPath = d.Get("storage_container_path").(string)
	}

	if d.HasChange("recurring_scans") {
		// Start from a cleared configuration with an explicit empty email list, so
		// that removing the block also clears previously configured recipients.
		recurringProps := servervulnerabilityassessments.VulnerabilityAssessmentRecurringScansProperties{
			Emails: pointer.To(make([]string, 0)),
		}

		if raw, ok := d.GetOk("recurring_scans"); ok {
			if blocks := raw.([]interface{}); len(blocks) != 0 {
				// Comma-ok assertion: a nil list element is treated as "no settings"
				// instead of panicking.
				if v, ok := blocks[0].(map[string]interface{}); ok {
					if enabled, ok := v["enabled"]; ok {
						recurringProps.IsEnabled = pointer.To(enabled.(bool))
					}

					if emailAdmins, ok := v["email_subscription_admins"]; ok {
						recurringProps.EmailSubscriptionAdmins = pointer.To(emailAdmins.(bool))
					}

					emails := make([]string, 0)
					if rawEmails, ok := v["emails"].([]interface{}); ok {
						for _, email := range rawEmails {
							emails = append(emails, email.(string))
						}
					}
					recurringProps.Emails = pointer.To(emails)
				}
			}
		}

		props.RecurringScans = pointer.To(recurringProps)
	}

	// NOTE: 'storage_account_access_key' and 'storage_container_sas_key'
	// are not returned by the API, so they are re-sent from configuration
	// whenever they are set.
	if v, ok := d.GetOk("storage_account_access_key"); ok {
		props.StorageAccountAccessKey = pointer.To(v.(string))
	}

	if v, ok := d.GetOk("storage_container_sas_key"); ok {
		props.StorageContainerSasKey = pointer.To(v.(string))
	}

	payload := servervulnerabilityassessments.ServerVulnerabilityAssessment{
		Properties: props,
	}

	update, err := client.CreateOrUpdate(ctx, serverId, payload)
	if err != nil {
		return fmt.Errorf("updating vulnerability assessment for SQL server %q: %w", serverId.ServerName, err)
	}

	// A successful call without an ID is still a failure. Report it separately so
	// the message does not wrap a nil error.
	if update.Model == nil || update.Model.Id == nil {
		return fmt.Errorf("updating vulnerability assessment for SQL server %q: response did not contain a resource ID", serverId.ServerName)
	}

	return resourceMsSqlServerVulnerabilityAssessmentRead(d, meta)
}

// resourceMsSqlServerVulnerabilityAssessmentDelete removes the vulnerability
// assessment from the SQL server encoded in the resource ID.
func resourceMsSqlServerVulnerabilityAssessmentDelete(d *pluginsdk.ResourceData, meta interface{}) error {
	armClient := meta.(*clients.Client)
	assessments := armClient.MSSQL.ServerVulnerabilityAssessmentsClient
	ctx, cancel := timeouts.ForDelete(armClient.StopContext, d)
	defer cancel()

	log.Printf("[INFO] deleting mssql server vulnerability assessment")

	assessmentId, parseErr := parse.ServerVulnerabilityAssessmentID(d.Id())
	if parseErr != nil {
		return parseErr
	}

	// The assessment is addressed through its parent SQL server.
	sqlServerId := commonids.NewSqlServerID(assessmentId.SubscriptionId, assessmentId.ResourceGroup, assessmentId.ServerName)

	_, deleteErr := assessments.Delete(ctx, sqlServerId)
	if deleteErr == nil {
		return nil
	}

	return fmt.Errorf("deleting vulnerability assessment for SQL server %q: %w", sqlServerId.ServerName, deleteErr)
}
